"""TileLang backward pass for a top-k attention indexer: given the forward
top-k probabilities and a target attention distribution, compute gradients
for the index queries, per-head weights, and index keys."""

import torch
import torch.nn.functional as F
from einops import einsum, repeat

import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices

from utils import get_abs_err, get_err_ratio

BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"

pass_configs = {
    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}


@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
    heads: int,
    dim: int,
    topk: int,
    sm_scale: Optional[float] = None,
    block_I: int = 32,
    num_stages: int = 0,
    num_threads: int = 128,
):
    assert num_stages == 0
    assert topk == tl.math.next_power_of_2(topk)
    assert topk % block_I == 0
    assert heads <= 64 and heads % 8 == 0
    batch_plus_one = T.symbolic("batch_plus_one")
    seq_len = T.symbolic("seq_len")
    dtype: str = BF16
    accum_dtype: str = FP32
    index_q_shape = [seq_len, heads, dim]
    weights_shape = [seq_len, heads]
    index_k_shape = [seq_len, dim]
    shape_p = [seq_len, topk]
    topk_indices_shape = [seq_len, topk]
    offsets_shape = [batch_plus_one]
    token_indices_shape = [seq_len, 2]
    if sm_scale is None:
        sm_scale = dim**-0.5

    @T.prim_func
    def tl_indexer_bwd_kernel(
            IndexQ: T.Tensor(index_q_shape, dtype),
            Weights: T.Tensor(weights_shape, dtype),
            IndexK: T.Tensor(index_k_shape, dtype),
            dIndexQ: T.Tensor(index_q_shape, dtype),
            dWeights: T.Tensor(weights_shape, dtype),
            dIndexK: T.Tensor(index_k_shape, dtype),
            AttnScore: T.Tensor(shape_p, FP32),
            IndexScore: T.Tensor(shape_p, FP32),
            TopkIndices: T.Tensor(topk_indices_shape, INT32),
            Offsets: T.Tensor(offsets_shape, INT32),
            TokenIndices: T.Tensor(token_indices_shape, INT32),
    ):
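        # One program per query token: recompute the ReLU logits for that
        # token's top-k keys, accumulate dQ and dWeights in registers, and
        # scatter per-key dK contributions with atomic adds.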
        with T.Kernel(seq_len, threads=num_threads) as (bx):
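            # Map the flat program id to (sequence id, in-sequence position)
            # of the packed variable-length batch.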
            i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
            bos = Offsets[i_b]
            num_blocks = T.ceildiv(topk, block_I)

            index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
            weights_shared = T.alloc_shared([heads], dtype=dtype)

            d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype)
            d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype)

            T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
            T.copy(Weights[bos + i_t, :], weights_shared)
            T.fill(d_index_q_frag, 0)
            T.fill(d_weights_frag, 0)

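            # Pre-scale Q once so the recomputed logits match the forward pass.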
            for i, j in T.Parallel(heads, dim):
                index_q_shared[i, j] = index_q_shared[i, j] * sm_scale

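            # Walk the top-k key list in tiles of block_I entries.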
            for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):

                i_st = bi_i * block_I
                i_ed = (bi_i + 1) * block_I

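                # Load this tile's top-k indices, then gather the selected keys;
                # invalid (-1) or non-causal positions read as zero.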
                indices_shared = T.alloc_shared([block_I], dtype=INT32)
                T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared)

                index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t),
                                                          IndexK[bos + pos, j], 0)

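                # Stage the forward probabilities (IndexScore) and the target
                # distribution (AttnScore) for this tile.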
                attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
                for i in T.Parallel(block_I):
                    attn_score_shared[i] = AttnScore[bos + i_t, i_st + i]
                    index_score_shared[i] = IndexScore[bos + i_t, i_st + i]

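                # Recompute the forward logits: relu(K_tile @ (sm_scale * Q)^T).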
                logits = T.alloc_fragment((block_I, heads), accum_dtype)
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )
                for i, j in T.Parallel(block_I, heads):
                    logits[i, j] = T.max(logits[i, j], 0)

                # dWeights: the KL gradient w.r.t. the score is (p - target),
                # and score = sum over keys of w * relu(logits), so each head's
                # dW accumulates (IndexScore - AttnScore) * relu(logits).
                d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
                for i, j in T.Parallel(block_I, heads):
                    d_weights_i[i, j] = (index_score_shared[i] -
                                         attn_score_shared[i]) * logits[i, j]
                T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)

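                # dLogits = (IndexScore - AttnScore) * relu'(logits) * w, staged
                # in shared memory and cast to bf16 once per consumer GEMM.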
                d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
                d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype)
                d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype)

                for i, j in T.Parallel(block_I, heads):
                    d_relu = T.alloc_var(accum_dtype)
                    if logits[i, j] > 0:
                        d_relu = 1.0
                    else:
                        d_relu = 0.0
                    d_logits_qk[i, j] = (index_score_shared[i] -
                                         attn_score_shared[i]) * d_relu * weights_shared[j]

                # dQ: accumulate dLogits^T @ K_tile into registers across tiles.
                T.copy(d_logits_qk, d_logits_qk_cast1)
                T.gemm(
                    d_logits_qk_cast1,  # [BS, HQ]
                    index_k_shared,  # [BS, K]
                    d_index_q_frag,  # [HQ, K]
                    transpose_A=True,
                    transpose_B=False,
                    clear_accum=False,
                )

                # dK for this tile: dLogits @ (sm_scale * Q); Q is already
                # scaled, so dK needs no extra factor.
                T.copy(d_logits_qk, d_logits_qk_cast2)
                d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
                T.gemm(
                    d_logits_qk_cast2,  # [BS, HQ]
                    index_q_shared,  # [HQ, K]
                    d_index_k_frag,  # [BS, K]
                    transpose_A=False,
                    transpose_B=False,
                    clear_accum=True,
                )

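                # Scatter this tile's dK with atomic adds: different query
                # tokens may select the same key.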
                for i, j in T.Parallel(block_I, dim):
                    pos = indices_shared[i]
                    if (pos > -1) & (pos <= i_t):
                        T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])

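            # Only Q was pre-scaled, so dQ = dLogits^T @ K still needs one
            # factor of sm_scale.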
            for i, j in T.Parallel(heads, dim):
                d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale

            T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :])
            T.copy(d_weights_frag, dWeights[bos + i_t, :])

    return tl_indexer_bwd_kernel


def indexer_bwd_interface(
    q: torch.Tensor,
    weights: torch.Tensor,
    k: torch.Tensor,
    attn_score: torch.Tensor,
    index_score: torch.Tensor,
    topk_indices: torch.Tensor,
    offsets: torch.Tensor,
):
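    """Launch the TileLang backward kernel over a packed varlen batch.

    `index_score` holds the forward top-k probabilities, `attn_score` the
    target distribution, and `offsets` delimits the sequences.
    """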
    _, heads, dim, topk = *q.shape, topk_indices.shape[-1]
    token_indices = prepare_token_indices(offsets)
    dq = torch.zeros_like(q)
    dweights = torch.zeros_like(weights)
    dk = torch.zeros_like(k)
    kernel = tl_indexer_bwd_impl(heads, dim, topk)
    kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets,
           token_indices)
    return dq, dweights, dk


def ref_indexer_bwd(
        Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor,
        TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
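    """Autograd reference: rebuild the forward KL loss per sequence, let
    PyTorch backpropagate, and return (index_score, dQ, dWeights, dK)."""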
    Q.requires_grad_(True)
    Weights.requires_grad_(True)
    K.requires_grad_(True)
    softmax_scale = Q.shape[-1]**-0.5
    all_loss = []
    all_log_topk_prob = []
    for i in range(offsets.shape[0] - 1):
        assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
        q = Q[offsets[i]:offsets[i + 1]]
        weights = Weights[offsets[i]:offsets[i + 1]]
        k = K[offsets[i]:offsets[i + 1]]
        topk_indices = TopkIndices[offsets[i]:offsets[i + 1]]
        attn_score = AttnScore[offsets[i]:offsets[i + 1]]
        s = q.shape[0]
        mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
        logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale
        logits = F.relu(logits)
        score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
        score = torch.where(mask, score, float('-inf'))
        topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
        log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
        loss = F.kl_div(
            log_topk_prob.clip(-100, 0),
            attn_score.log().clip(-100, 0),
            log_target=True,
            reduction="sum")
        all_loss.append(loss)
        all_log_topk_prob.append(log_topk_prob)
    loss = torch.stack(all_loss).sum()
    loss.backward()
    log_topk_prob = torch.cat(all_log_topk_prob, dim=0)
    return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad


def test_kernel(
    S=2048,
    H=16,
    D=128,
    topk=64,
):
    torch.manual_seed(42)
    q = torch.randn((S, H, D)).cuda().bfloat16()
    w = torch.randn((S, H)).cuda().bfloat16()
    k = torch.randn((S, D)).cuda().bfloat16()
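    # Two packed variable-length sequences: tokens [0, 1023) and [1023, S).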
    offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()

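    # Target distribution: uniform over the causally valid prefix of the
    # top-k slots at each query position.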
    all_attn_score = []
    for i in range(offsets.shape[0] - 1):
        seq_len = (offsets[i + 1] - offsets[i]).item()
        mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
        logits = torch.ones(seq_len, topk).cuda()
        logits = torch.where(mask, logits, float('-inf'))
        attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
        all_attn_score.append(attn_score)
    attn_score = torch.cat(all_attn_score, dim=0)

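    # Every token selects the first `topk` positions of its own sequence.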
    topk_indices = repeat(
        torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous()
    index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score,
                                                          offsets)

    dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)

    print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
    print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
    print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")


if __name__ == '__main__':
    test_kernel()
